# -*- coding: utf-8 -*-
"""
Backbones of Pre-training Models (from input to last hidden-layer output)
"""

import torch
import torch.nn as nn
from torch_scatter import scatter_max, scatter_mean

import qa.tuta.model.embeddings as emb
import qa.tuta.model.encoders as enc
from qa.tuta.preprocess import TUTAPreprocessor
from qa.table_bert.config import BertConfig
from transformers.models.bert.tokenization_bert import BertTokenizer
from qa.tuta.tokenizer import *
from qa.datadump.utils import *


class Backbone(nn.Module):
    """Common base of the backbones.

    Provides the tree-position expansion and the attention-distance helpers
    shared by the concrete backbones. It defines no ``forward`` of its own.
    """

    def __init__(self, config):
        """
        args: config: object exposing `node_degree` (iterable of ints) and `attn_method`
        """
        super().__init__()
        self.total_node = sum(config.node_degree)
        # Key into `attn_methods`; not consulted while the tree-based mask in
        # `get_attention_mask` is disabled.
        self.attn_method = config.attn_method
        self.attn_methods = {"max": self.pos2attn_max,
                             "add": self.pos2attn_add}

    def unzip_tree_position(self, zipped_position):  # *dist*0/1/2
        """
        args: zipped_position: [batch_size, seq_len, tree_depth], range: [0, total_node]
        rets: entire_position: [batch_size, seq_len, total_node], dtype long, values in {0, 1}
        lower_bound = 0, upper_bound = (total_node-1)
        use one excessive bit to temporarily represent not-applicable nodes
        """
        batch_size, seq_len, _ = zipped_position.size()
        # Allocate directly as long on the input's device (no float buffer, no host->device copy).
        entire_position = torch.zeros(batch_size, seq_len, self.total_node + 1,
                                      dtype=torch.long, device=zipped_position.device)
        entire_position.scatter_(-1, zipped_position, 1)
        return entire_position[:, :, : self.total_node]  # remove last (not-applicable) column

    def get_attention_mask(self, entire_top, entire_left, indicator):
        """
        args: entire_top / entire_left: entire-mode tree positions (currently unused, see TODO)
              indicator: [batch_size, seq_len]
        rets: attention_mask: [batch_size, seq_len, seq_len], all-zero distances plus the
              padding penalty added by `create_post_mask`
        """
        # TODO: tree-based distances are disabled; re-enable with
        #   self.attn_methods[self.attn_method](entire_top, entire_left)
        b, seq_len = indicator.size()
        attention_mask = torch.zeros((b, seq_len, seq_len), device=indicator.device)
        return self.create_post_mask(attention_mask, indicator)

    def pos2attn_max(self, pos_top, pos_left):  # entire position
        """Element-wise maximum of the top-tree and left-tree distance matrices."""
        return torch.max(self.pos2attn(pos_top), self.pos2attn(pos_left))

    def pos2attn_add(self, pos_top, pos_left):  # entire position
        """Element-wise sum of the top-tree and left-tree distance matrices."""
        return self.pos2attn(pos_top) + self.pos2attn(pos_left)

    def pos2attn(self, position):  # entire position
        """Compute a one-dimension attention distance matrix from a entire-mode tree position.
        Args:
            position: (b, seq_len, total_node), e.g. total_node = 384
        Returns:
            (b, seq_len, seq_len); entry [i, j] is the sum over nodes of |position[i] - position[j]|
        """
        # Broadcasting yields the same pairwise differences as repeating the
        # tensor seq_len times and subtracting its transpose.
        pairwise_diff = position.unsqueeze(2) - position.unsqueeze(1)  # [batch, seq_len, seq_len, total_node]
        return pairwise_diff.abs().sum(dim=-1)

    def create_post_mask(self, attn_dist, indicator, padding_dist=100):
        """
        Add `padding_dist` to every pair (i, j) in which token i or token j has indicator == 0
        (the value treated as [PAD] here).
        args: attn_dist: [batch_size, seq_len, seq_len]
              indicator: [batch_size, seq_len]
        rets: a new tensor of attn_dist's shape; attn_dist itself is not modified

        TODO: the [CLS] / table-[SEP] specific visibility rules (indicator == -1, odd positive
        indicators) are not applied at the moment; only the padding penalty is.
        """
        # NOTE: no torch.set_printoptions here -- changing global print settings on every
        # forward pass was a debugging leftover.
        is_pad = (indicator == 0).long()
        # [b, seq, 1] vs [b, 1, seq] -> [b, seq, seq]; 1 where either side of the pair is padding
        pad_matrix = torch.max(is_pad.unsqueeze(-1), is_pad.unsqueeze(-2)) * padding_dist
        return attn_dist + pad_matrix


class BbForBase(Backbone):
    """Backbone built from `emb.EmbeddingForBase` and `enc.Encoder`."""

    def __init__(self, config):
        # Run Backbone.__init__ so total_node / attn_method / attn_methods are set in
        # one place (the former super(Backbone, self) call skipped it and duplicated them).
        super().__init__(config)
        self.embeddings = emb.EmbeddingForBase(config)
        self.encoder = enc.Encoder(config)

    def forward(self, token_id, num_mag, num_pre, num_top, num_low, token_order, pos_top, pos_left, format_vec,
                indicator):
        """Embed the inputs, build the attention mask from `indicator`, and run the encoder.

        `pos_top` / `pos_left` are zipped tree positions; they are expanded with
        `unzip_tree_position` and handed to `get_attention_mask`.
        Returns whatever `self.encoder(embedded_states, attn_mask)` returns.
        """
        embedded_states = self.embeddings(token_id, num_mag, num_pre, num_top, num_low, token_order, format_vec)
        entire_pos_top = self.unzip_tree_position(pos_top)
        entire_pos_left = self.unzip_tree_position(pos_left)
        attn_mask = self.get_attention_mask(entire_pos_top, entire_pos_left, indicator)
        return self.encoder(embedded_states, attn_mask)


class BbForTutaExplicit(Backbone):
    """Backbone built from `emb.EmbeddingForTutaExplicit` and `enc.Encoder`.

    Unlike the other backbones, the embedding layer here receives the expanded
    (entire-mode) tree positions rather than the zipped ones.
    """

    def __init__(self, config):
        # Run Backbone.__init__ so total_node / attn_method / attn_methods are set in
        # one place (the former super(Backbone, self) call skipped it and duplicated them).
        super().__init__(config)
        self.embeddings = emb.EmbeddingForTutaExplicit(config)
        self.encoder = enc.Encoder(config)

    def forward(self,
                token_id, num_mag, num_pre, num_top, num_low,
                token_order, pos_row, pos_col, pos_top, pos_left,
                format_vec, indicator
                ):
        """Expand the tree positions, embed the inputs, build the mask, and run the encoder.

        Returns whatever `self.encoder(embedded_states, attn_mask)` returns.
        """
        # The expanded positions feed both the embedding layer and the mask builder.
        entire_pos_top = self.unzip_tree_position(pos_top)
        entire_pos_left = self.unzip_tree_position(pos_left)
        embedded_states = self.embeddings(
            token_id, num_mag, num_pre, num_top, num_low,
            token_order, pos_row, pos_col, entire_pos_top, entire_pos_left, format_vec
        )
        attn_mask = self.get_attention_mask(entire_pos_top, entire_pos_left, indicator)
        return self.encoder(embedded_states, attn_mask)


class BbForTuta(Backbone):
    """Backbone built from `emb.EmbeddingForTuta` and `enc.Encoder`, plus a high-level
    `encode` entry point that preprocesses (context, table) pairs and pools the encoder
    output into question / header / index-name representations.
    """

    def __init__(self, config):
        # Run Backbone.__init__ so total_node / attn_method / attn_methods are set in
        # one place (the former super(Backbone, self) call skipped it and duplicated them).
        super().__init__(config)
        self.embeddings = emb.EmbeddingForTuta(config)
        self.encoder = enc.Encoder(config)
        self.config = config

        self.preprocessor = TUTAPreprocessor(self.config)

        # table bert
        self.output_size = config.hidden_size
        self.bert_config = BertConfig(
            vocab_size_or_config_json_file=config.vocab_size,
            attention_probs_dropout_prob=0.1,
            hidden_act=config.hidden_act,
            hidden_dropout_prob=config.hidden_dropout_prob,
            hidden_size=config.hidden_size,
            initializer_range=0.02,
            intermediate_size=config.intermediate_size,
            # layer_norm_eps=1e-12,
            max_position_embeddings=512,
            num_attention_heads=config.num_attention_heads,
            num_hidden_layers=12,
            type_vocab_size=2,
        )
        # self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # only for tokenizing question
        self.tokenizer = HMTTokenizer(config)

    def forward(self,
                token_id, num_mag, num_pre, num_top, num_low,
                token_order, pos_row, pos_col, pos_top, pos_left,
                format_vec, indicator, position_ids, segment_ids,
                ):
        """Embed the inputs, build the attention mask from `indicator`, and run the encoder.

        The embedding layer receives the zipped `pos_top` / `pos_left`; the expanded
        versions are only handed to `get_attention_mask`.
        Returns whatever `self.encoder(embedded_states, attn_mask)` returns.
        """
        embedded_states = self.embeddings(
            token_id, num_mag, num_pre, num_top, num_low,
            token_order, pos_row, pos_col, pos_top, pos_left, format_vec, position_ids, segment_ids
        )
        entire_pos_top = self.unzip_tree_position(pos_top)
        entire_pos_left = self.unzip_tree_position(pos_left)
        attn_mask = self.get_attention_mask(entire_pos_top, entire_pos_left, indicator)
        return self.encoder(embedded_states, attn_mask)

    def encode(self, contexts, tables):
        """ High-level encode, including data preprocessing for TUTA and forward().

        Returns:
            (question_encoding, header_encoding, index_name_encoding, info) where
            info is {'tensor_dict': mask_info} and mask_info is the dict produced by the
            preprocessor with its tensors moved to the model's device.
        """
        batch_data, mask_info = self.preprocessor.pipeline(contexts, tables)
        # Follow the model's own device instead of guessing from CUDA availability, so
        # inputs always land where the embedding weights are.
        device = next(self.embeddings.parameters()).device
        batch_input = [tensor.to(device) for tensor in batch_data]
        for key, value in mask_info.items():
            mask_info[key] = value.to(device)  # values replaced in place; key set unchanged
        encoded_states = self.forward(*batch_input)  # e.g. 2 * 464 * 768
        question_encoding = self.get_context_representation(encoded_states,
                                                            mask_info['context_token_indices'],
                                                            mask_info['context_token_mask'])
        header_encoding = self.get_header_representation(encoded_states,
                                                         mask_info['header_token_indices'],
                                                         mask_info['header_mask'])
        index_name_encoding = self.get_index_name_representation(encoded_states,
                                                                 mask_info['level_token_indices'],
                                                                 mask_info['level_mask'])
        return question_encoding, header_encoding, index_name_encoding, {'tensor_dict': mask_info}

    def get_context_representation(self, encoded_states, context_token_indices, context_mask):
        """ Extract context encoding from the encoder output.

        Gathers the rows of `encoded_states` (b, seq_len, hidden) selected by
        `context_token_indices` (b, n) along dim 1, then multiplies by `context_mask` (b, n)
        so masked positions become zero. Returns a (b, n, hidden) tensor.
        """
        gather_index = context_token_indices.unsqueeze(-1).expand(-1, -1, encoded_states.size(-1))
        context_encoding = torch.gather(encoded_states, dim=1, index=gather_index)
        return context_encoding * context_mask.unsqueeze(-1)

    @staticmethod
    def _aggregate_token_encoding(flattened_encoding, token_indices, mask, aggregator, what):
        """Pool token encodings into per-slot encodings with a scatter reduction over dim 1.

        Args:
            flattened_encoding: (b, seq_len, enc) token encodings (not modified).
            token_indices: (b, seq_len) target slot of every token; the extra slot
                `mask.size(-1)` is reserved as a "garbage collection" entry and dropped.
                NOTE(review): assumes the preprocessor routes tokens that belong to no slot
                to that extra entry -- confirm against TUTAPreprocessor.
            mask: (b, slot_num), multiplied into the result to zero out padding slots.
            aggregator: name starting with 'max_pool', 'mean_pool' or 'first_token'
                ('first_token' is pooled with the mean, same as 'mean_pool').
            what: label used in the error message for an unknown aggregator.
        Returns:
            (b, slot_num, enc) tensor.
        Raises:
            ValueError: unknown aggregator. Errors raised by the scatter op are logged
            with the involved shapes and re-raised unchanged.
        """
        if aggregator.startswith('max_pool'):
            agg_func, returns_argmax = scatter_max, True
        elif aggregator.startswith(('mean_pool', 'first_token')):
            agg_func, returns_argmax = scatter_mean, False
        else:
            raise ValueError(f'Unknown {what} representation method {aggregator}')

        slot_num = mask.size(-1)
        try:
            result = agg_func(
                flattened_encoding,  # src
                token_indices.unsqueeze(-1).expand(-1, -1, flattened_encoding.size(-1)),  # idx
                dim=1,
                dim_size=slot_num + 1)
        except Exception:
            # Report the shapes involved, then let the real error surface. (A bare
            # `except:` used to swallow it and fail later on an unbound `result`.)
            import logging
            logging.getLogger(__name__).error(
                "scatter aggregation failed for %s: encoding=%s, token_indices=%s, mask=%s, dim_size=%d",
                what, tuple(flattened_encoding.size()), tuple(token_indices.size()),
                tuple(mask.size()), slot_num + 1)
            raise

        if returns_argmax:
            # scatter_max returns (values, argmax); keep the values before any slicing.
            result = result[0]

        # remove the last "garbage collection" entry, mask out padding slots
        return result[:, :-1] * mask.unsqueeze(-1)  # (b, slot_num, enc), e.g. (b, 8, 768)

    @staticmethod
    def get_index_name_representation(
            flattened_index_name_encoding: torch.Tensor,
            level_token_indices: torch.Tensor,
            level_mask: torch.Tensor,
            encode_field: str = 'level',
            aggregator: str = 'mean_pool'
    ):
        """ Aggregate encoding of each index name according to level_token_indices/level_mask.

        Args:
            flattened_index_name_encoding: (b, seq_len, enc) encoder output.
            level_token_indices: (b, seq_len) level slot of every token.
            level_mask: (b, level_num) mask of valid levels.
            encode_field: only 'level' is supported.
            aggregator: name starting with 'max_pool', 'mean_pool' or 'first_token'.
        Returns:
            (b, level_num, enc) tensor with padding levels zeroed.
        Raises:
            ValueError: unsupported `encode_field` or `aggregator`.
        """
        if encode_field != 'level':
            raise ValueError(f"Unknown encode_field {encode_field!r}; only 'level' is supported")

        return BbForTuta._aggregate_token_encoding(
            flattened_index_name_encoding, level_token_indices, level_mask, aggregator, 'index name')

    @staticmethod
    def get_header_representation(
        flattened_header_encoding: torch.Tensor,
        header_token_indices: torch.Tensor,
        header_mask: torch.Tensor,
        aggregator: str = 'mean_pool'
    ):
        """ Aggregate encoding of each header according to header_token_indices/header_mask.

        Args:
            flattened_header_encoding: (b, seq_len, enc) encoder output.
            header_token_indices: (b, seq_len) header slot of every token.
            header_mask: (b, header_num) mask of valid headers.
            aggregator: name starting with 'max_pool', 'mean_pool' or 'first_token'.
        Returns:
            (b, header_num, enc) tensor with padding headers zeroed.
        Raises:
            ValueError: unsupported `aggregator`.
        """
        return BbForTuta._aggregate_token_encoding(
            flattened_header_encoding, header_token_indices, header_mask, aggregator, 'header')



# Registry mapping a backbone name to the class that implements it.
BACKBONES = {
    "tuta": BbForTuta,
    "base": BbForBase,
    "tuta_explicit": BbForTutaExplicit,
}
